#!/usr/bin/env python
# -*- coding: utf-8 -*-

"""
LLaDA GUI - A graphical interface for interacting with the LLaDA language model.
"""

import os
import sys
import torch
from PyQt6.QtWidgets import (
    QApplication, QMainWindow, QWidget, QVBoxLayout, QHBoxLayout, 
    QTextEdit, QPushButton, QLabel, QSpinBox, QComboBox, QGroupBox,
    QCheckBox, QProgressBar, QSplitter, QMessageBox, QGridLayout,
    QScrollArea, QDoubleSpinBox, QTabWidget, QRadioButton, QButtonGroup
)
from PyQt6.QtCore import Qt, QThread, pyqtSignal, QTimer
from PyQt6.QtGui import QFont, QTextCursor

# Import our modules
from config import WINDOW_TITLE, WINDOW_WIDTH, WINDOW_HEIGHT, SPLITTER_RATIO, DEFAULT_PARAMS
from memory_monitor import MemoryMonitor
from llada_worker import LLaDAWorker
from diffusion_visualization import DiffusionProcessVisualizer
from utils import format_memory_info

class LLaDAGUI(QMainWindow):
    """Main GUI for the LLaDA application."""
    
    def __init__(self):
        """Create the main window, build its widgets and start memory monitoring."""
        super().__init__()

        # Window title and initial size come from the config module
        self.setWindowTitle(WINDOW_TITLE)
        self.resize(WINDOW_WIDTH, WINDOW_HEIGHT)

        # No generation thread exists until start_generation() creates one
        self.worker = None

        # The monitor reports through its `update` signal into update_memory_info()
        monitor = MemoryMonitor()
        monitor.update.connect(self.update_memory_info)
        self.memory_monitor = monitor

        # Build the widgets first, then fill the output pane with the intro text
        self.init_ui()
        self.setup_welcome_message()

        # Only start monitoring once the labels it updates exist
        monitor.start()
    
    def closeEvent(self, event):
        """Shut down background activity before the window closes.

        Stops the memory monitor, asks a running worker thread to stop
        (forcing termination if it does not finish in time), releases cached
        CUDA memory and finally accepts the close event.

        Args:
            event: The close event delivered by Qt; it is always accepted.
        """
        # Stop memory monitoring
        self.memory_monitor.stop()

        # Stop worker thread if running
        worker = self.worker
        if worker is not None and worker.isRunning():
            worker.stop()

            # wait() returns False when the timeout expires with the thread
            # still running; only then fall back to forced termination.
            if not worker.wait(1000):
                worker.terminate()
                # Qt recommends waiting after terminate() so the thread is
                # really gone before its QThread object is destroyed. The
                # wait is bounded so closing the window can never hang.
                worker.wait(2000)

        # Clean up GPU memory
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        # Accept the close event
        event.accept()

    def init_ui(self):
        """Build the window: input panel on top, output panel below.

        The two panels live in a vertical splitter that becomes the central
        widget. Construction of each panel is delegated to private helpers;
        every widget attribute used elsewhere in the class is created there.
        """
        main_widget = QWidget()
        main_layout = QVBoxLayout(main_widget)

        # A splitter lets the user trade space between input and output
        splitter = QSplitter(Qt.Orientation.Vertical)
        main_layout.addWidget(splitter)

        splitter.addWidget(self._build_input_panel())
        splitter.addWidget(self._build_output_panel())
        splitter.setSizes(SPLITTER_RATIO)

        self.setCentralWidget(main_widget)

    @staticmethod
    def _make_spin_box(spin_cls, minimum, maximum, value, step):
        """Return a spin box of ``spin_cls`` with range, value and step applied."""
        spin = spin_cls()
        spin.setRange(minimum, maximum)
        spin.setValue(value)
        spin.setSingleStep(step)
        return spin

    def _build_input_panel(self):
        """Create the top panel: resources, prompt box, parameters and buttons."""
        input_widget = QWidget()
        input_layout = QVBoxLayout(input_widget)

        input_layout.addWidget(self._build_memory_group())

        # Prompt entry
        input_label = QLabel("Enter your prompt:")
        input_label.setFont(QFont("Arial", 10, QFont.Weight.Bold))
        input_layout.addWidget(input_label)

        self.input_text = QTextEdit()
        self.input_text.setMinimumHeight(100)
        self.input_text.setPlaceholderText("Type your prompt here...")
        input_layout.addWidget(self.input_text)

        input_layout.addWidget(self._build_params_group())
        input_layout.addLayout(self._build_button_row())
        return input_widget

    def _build_memory_group(self):
        """Create the "System Resources" box holding the RAM and GPU labels."""
        self.memory_group = QGroupBox("System Resources")
        memory_layout = QGridLayout(self.memory_group)

        # System memory
        memory_layout.addWidget(QLabel("System RAM:"), 0, 0)
        self.system_memory_label = QLabel("- / - GB (-%)")
        memory_layout.addWidget(self.system_memory_label, 0, 1)

        # GPU memory
        memory_layout.addWidget(QLabel("GPU Memory:"), 0, 2)
        self.gpu_memory_label = QLabel("- / - GB (-%)")
        memory_layout.addWidget(self.gpu_memory_label, 0, 3)

        return self.memory_group

    def _build_params_group(self):
        """Create the "Generation Parameters" box, including the hardware options."""
        params_group = QGroupBox("Generation Parameters")
        params_layout = QGridLayout(params_group)

        # Each control is created and placed immediately so widgets join the
        # layout in reading order (left to right, top to bottom).
        self.gen_length_spin = self._make_spin_box(
            QSpinBox, 16, 512, DEFAULT_PARAMS['gen_length'], 16)
        params_layout.addWidget(QLabel("Generation Length:"), 0, 0)
        params_layout.addWidget(self.gen_length_spin, 0, 1)

        self.steps_spin = self._make_spin_box(
            QSpinBox, 16, 512, DEFAULT_PARAMS['steps'], 16)
        params_layout.addWidget(QLabel("Sampling Steps:"), 0, 2)
        params_layout.addWidget(self.steps_spin, 0, 3)

        self.block_length_spin = self._make_spin_box(
            QSpinBox, 16, 256, DEFAULT_PARAMS['block_length'], 16)
        params_layout.addWidget(QLabel("Block Length:"), 1, 0)
        params_layout.addWidget(self.block_length_spin, 1, 1)

        self.temperature_spin = self._make_spin_box(
            QDoubleSpinBox, 0, 2, DEFAULT_PARAMS['temperature'], 0.1)
        params_layout.addWidget(QLabel("Temperature:"), 1, 2)
        params_layout.addWidget(self.temperature_spin, 1, 3)

        self.cfg_scale_spin = self._make_spin_box(
            QDoubleSpinBox, 0, 5, DEFAULT_PARAMS['cfg_scale'], 0.1)
        params_layout.addWidget(QLabel("CFG Scale:"), 2, 0)
        params_layout.addWidget(self.cfg_scale_spin, 2, 1)

        self.remasking_combo = QComboBox()
        self.remasking_combo.addItems(["low_confidence", "random"])
        self.remasking_combo.setCurrentText(DEFAULT_PARAMS['remasking'])
        params_layout.addWidget(QLabel("Remasking Strategy:"), 2, 2)
        params_layout.addWidget(self.remasking_combo, 2, 3)

        # The hardware box spans the full width of the fourth row
        params_layout.addWidget(self._build_device_group(), 3, 0, 1, 4)

        return params_group

    def _build_device_group(self):
        """Create the "Hardware & Memory Options" box (device and precision radios)."""
        device_group = QGroupBox("Hardware & Memory Options")
        device_layout = QGridLayout(device_group)

        # Device selection
        device_layout.addWidget(QLabel("Device:"), 0, 0)
        self.device_group = QButtonGroup()

        self.cpu_radio = QRadioButton("CPU")
        self.gpu_radio = QRadioButton("GPU (CUDA)")

        # Default to the GPU when CUDA is usable; otherwise lock to CPU
        if torch.cuda.is_available():
            self.gpu_radio.setChecked(True)
        else:
            self.cpu_radio.setChecked(True)
            self.gpu_radio.setEnabled(False)

        self.device_group.addButton(self.cpu_radio, 0)
        self.device_group.addButton(self.gpu_radio, 1)

        device_layout.addWidget(self.cpu_radio, 0, 1)
        device_layout.addWidget(self.gpu_radio, 0, 2)

        # Memory optimization options
        device_layout.addWidget(QLabel("Memory Optimization:"), 1, 0)

        self.use_normal = QRadioButton("Normal Precision")
        self.use_8bit = QRadioButton("8-bit Quantization")
        self.use_4bit = QRadioButton("4-bit Quantization")

        self.precision_group = QButtonGroup()
        self.precision_group.addButton(self.use_normal, 0)
        self.precision_group.addButton(self.use_8bit, 1)
        self.precision_group.addButton(self.use_4bit, 2)

        self.use_8bit.setChecked(True)  # Default to 8-bit for safety

        device_layout.addWidget(self.use_normal, 1, 1)
        device_layout.addWidget(self.use_8bit, 1, 2)
        device_layout.addWidget(self.use_4bit, 1, 3)

        # Keep the quantization radios in step with the device choice
        self.cpu_radio.toggled.connect(self.update_quantization_options)
        self.gpu_radio.toggled.connect(self.update_quantization_options)

        # The initial setChecked() calls above ran before the handlers were
        # connected, so apply the rule once by hand. Without this, a CPU-only
        # machine starts with "8-bit Quantization" checked and enabled.
        self.update_quantization_options()

        return device_group

    def _build_button_row(self):
        """Create the Generate / Stop / Clear button row."""
        button_layout = QHBoxLayout()

        self.generate_btn = QPushButton("Generate")
        self.generate_btn.setFont(QFont("Arial", 10, QFont.Weight.Bold))
        self.generate_btn.clicked.connect(self.start_generation)
        button_layout.addWidget(self.generate_btn)

        # Stop is only meaningful while a generation is in progress
        self.stop_btn = QPushButton("Stop")
        self.stop_btn.setEnabled(False)
        self.stop_btn.clicked.connect(self.stop_generation)
        button_layout.addWidget(self.stop_btn)

        self.clear_btn = QPushButton("Clear")
        self.clear_btn.clicked.connect(self.clear_output)
        button_layout.addWidget(self.clear_btn)

        return button_layout

    def _build_output_panel(self):
        """Create the bottom panel: progress, status and the output/visualization tabs."""
        output_widget = QWidget()
        output_layout = QVBoxLayout(output_widget)

        output_label = QLabel("Generated Output:")
        output_label.setFont(QFont("Arial", 10, QFont.Weight.Bold))
        output_layout.addWidget(output_label)

        # Progress bar stays hidden until a generation starts
        self.progress_bar = QProgressBar()
        self.progress_bar.setVisible(False)
        output_layout.addWidget(self.progress_bar)

        self.status_label = QLabel()
        output_layout.addWidget(self.status_label)

        self.tab_widget = QTabWidget()

        # Tab 0: text output
        text_tab = QWidget()
        text_layout = QVBoxLayout(text_tab)

        self.output_text = QTextEdit()
        self.output_text.setReadOnly(True)
        self.output_text.setPlaceholderText("Generated text will appear here...")
        text_layout.addWidget(self.output_text)

        self.tab_widget.addTab(text_tab, "Text Output")

        # Tab 1: diffusion visualization
        viz_tab = QWidget()
        viz_layout = QVBoxLayout(viz_tab)

        self.diffusion_viz = DiffusionProcessVisualizer()
        viz_layout.addWidget(self.diffusion_viz)

        self.tab_widget.addTab(viz_tab, "Diffusion Visualization")

        output_layout.addWidget(self.tab_widget)
        return output_widget
    
    def update_memory_info(self, stats):
        """Refresh the RAM and GPU labels and colour the GPU label by usage.

        Args:
            stats: Mapping passed to format_memory_info(); when its
                'gpu_available' entry is truthy, 'gpu_percent' is read to
                choose the label colour.
        """
        system_text, gpu_text = format_memory_info(stats)
        self.system_memory_label.setText(system_text)
        self.gpu_memory_label.setText(gpu_text)

        # Without GPU stats the label styling is left untouched
        if not stats.get('gpu_available', False):
            return

        gpu_percent = stats['gpu_percent']
        if gpu_percent > 90:
            style = "color: red; font-weight: bold"
        elif gpu_percent > 75:
            style = "color: orange"
        else:
            style = ""
        self.gpu_memory_label.setStyleSheet(style)
    
    def update_quantization_options(self):
        """Enable the quantization radios on GPU; force normal precision on CPU."""
        cpu_selected = self.cpu_radio.isChecked()

        # On CPU, fall back to normal precision before greying out the rest
        if cpu_selected:
            self.use_normal.setChecked(True)

        for quant_radio in (self.use_8bit, self.use_4bit):
            quant_radio.setEnabled(not cpu_selected)
    
    def setup_welcome_message(self):
        """Fill the text output pane with the HTML introduction and usage guide."""
        # The text is rendered as rich text, so the markup below is literal HTML
        welcome_html = """
<h2>Welcome to LLaDA GUI</h2>

<p>This application provides a graphical interface to the LLaDA (Large Language Diffusion with mAsking) model, 
an 8B scale diffusion model for language generation.</p>

<h3>Key Features:</h3>
<ul>
<li>Diffusion-based text generation using masking</li>
<li>Model rivals LLaMA3 8B in performance</li>
<li>Configurable generation parameters</li>
<li>Visualization of the diffusion process</li>
<li>Memory-optimized operation with 4-bit and 8-bit quantization</li>
<li>CPU fallback for low-memory situations</li>
</ul>

<h3>Memory-Safe Operation:</h3>
<p>This GUI includes several features to prevent out-of-memory errors:</p>
<ul>
<li>Real-time memory monitoring for both system RAM and GPU</li>
<li>Automatic parameter adjustment based on available memory</li>
<li>Optional quantization (4-bit or 8-bit) for reduced memory usage</li>
<li>CPU fallback mode when GPU memory is insufficient</li>
</ul>

<h3>Usage:</h3>
<p>1. Enter your prompt in the input box<br>
2. Adjust generation parameters as needed<br>
3. Select hardware and memory options<br>
4. Click "Generate" to start the diffusion process<br>
5. View the generated output in the "Text Output" tab<br>
6. Switch to the "Diffusion Visualization" tab to see the generation process</p>

<h3>Optimizing for Your Hardware:</h3>
<p>If you experience out-of-memory errors:</p>
<ul>
<li>Reduce generation length and steps</li>
<li>Try 8-bit or 4-bit quantization</li>
<li>Switch to CPU mode for most reliable (but slower) operation</li>
</ul>

<p>Ready to start? Enter a prompt above and click "Generate"!</p>
"""
        viewer = self.output_text
        viewer.setHtml(welcome_html)
    
    def get_generation_config(self):
        """Collect the generation settings currently selected in the UI.

        Returns:
            dict with keys 'gen_length', 'steps', 'block_length',
            'temperature', 'cfg_scale', 'remasking', 'device', 'use_8bit'
            and 'use_4bit'. The quantization flags are only True when the
            resolved device is 'cuda'.
        """
        # GPU is used only if it is both selected and actually available
        cuda_selected = self.gpu_radio.isChecked() and torch.cuda.is_available()
        device = 'cuda' if cuda_selected else 'cpu'
        on_gpu = device == 'cuda'

        config = {
            'gen_length': self.gen_length_spin.value(),
            'steps': self.steps_spin.value(),
            'block_length': self.block_length_spin.value(),
            'temperature': self.temperature_spin.value(),
            'cfg_scale': self.cfg_scale_spin.value(),
            'remasking': self.remasking_combo.currentText(),
            'device': device,
        }

        # Quantization is ignored whenever generation runs on the CPU
        config['use_8bit'] = self.use_8bit.isChecked() and on_gpu
        config['use_4bit'] = self.use_4bit.isChecked() and on_gpu
        return config
    
    def start_generation(self):
        """Validate the prompt, prepare the UI and launch a worker thread."""
        prompt = self.input_text.toPlainText().strip()

        # Refuse to run on an empty (or whitespace-only) prompt
        if not prompt:
            QMessageBox.warning(self, "Empty Prompt", "Please enter a prompt before generating.")
            return

        settings = self.get_generation_config()

        # Lock the inputs for the duration of the run
        self.set_controls_enabled(False)

        # Reset progress display and previous output
        self.progress_bar.setValue(0)
        self.progress_bar.setVisible(True)
        self.status_label.setText("Initializing...")
        self.output_text.clear()

        # Prepare the visualizer for this run's length and step count
        self.diffusion_viz.setup_process(settings['gen_length'], settings['steps'])

        # Create the worker and route each of its signals to the matching slot
        worker = LLaDAWorker(prompt, settings)
        self.worker = worker
        signal_routes = (
            (worker.progress, self.update_progress),
            (worker.step_update, self.update_visualization),
            (worker.finished, self.generation_finished),
            (worker.error, self.generation_error),
            (worker.memory_warning, self.display_memory_warning),
        )
        for signal, slot in signal_routes:
            signal.connect(slot)
        worker.start()

        # The run is now in progress, so stopping becomes possible
        self.stop_btn.setEnabled(True)

        # Show the visualization tab while generating
        self.tab_widget.setCurrentIndex(1)
    
    def stop_generation(self):
        """Ask the running worker to stop; do nothing if no worker is running."""
        worker = self.worker
        if not worker or not worker.isRunning():
            return

        worker.stop()
        self.status_label.setText("Stopping generation...")
        # Avoid sending repeated stop requests while the worker winds down
        self.stop_btn.setEnabled(False)
    
    def display_memory_warning(self, warning_msg):
        """Show ``warning_msg`` in a modal "Memory Warning" dialog."""
        dialog_title = "Memory Warning"
        QMessageBox.warning(self, dialog_title, warning_msg)
    
    def update_progress(self, progress, status, data):
        """Show the latest progress value, status text and any partial output.

        Args:
            progress: Value applied to the progress bar.
            status: Text shown in the status label.
            data: Mapping that may carry a 'partial_output' entry; a missing
                or empty entry leaves the output pane unchanged.
        """
        self.progress_bar.setValue(progress)
        self.status_label.setText(status)

        partial = data['partial_output'] if 'partial_output' in data else None
        if not partial:
            return

        # Replace the text and keep the caret at the end of the output
        self.output_text.setText(partial)
        end_cursor = self.output_text.textCursor()
        end_cursor.movePosition(QTextCursor.MoveOperation.End)
        self.output_text.setTextCursor(end_cursor)
    
    def update_visualization(self, step, tokens, masks, confidences):
        """Forward one diffusion step's data to the visualizer widget unchanged."""
        visualizer = self.diffusion_viz
        visualizer.update_process(step, tokens, masks, confidences)
    
    def generation_finished(self, result):
        """Show the final result and return the UI to its idle state.

        Args:
            result: Generated text placed in the output pane.
        """
        self.output_text.setText(result)
        self.progress_bar.setValue(100)
        self.status_label.setText("Generation complete")
        self.set_controls_enabled(True)

        # Switch to the text output tab
        self.tab_widget.setCurrentIndex(0)

        # Remember which run this timer belongs to. start_generation()
        # assigns a fresh worker for every run, so an identity check tells
        # us whether a newer generation has started in the meantime.
        finished_worker = self.worker

        def hide_progress_if_idle():
            # Do not hide the bar if a new generation began during the delay;
            # that run owns the progress bar now.
            if self.worker is not finished_worker:
                return
            self.progress_bar.setVisible(False)

        # Leave the completed bar on screen briefly, then hide it
        QTimer.singleShot(3000, hide_progress_if_idle)
    
    def generation_error(self, error_msg):
        """Report a failed generation in the output pane and in a dialog.

        Args:
            error_msg: Error text; a message containing
                "CUDA out of memory" triggers a dialog with memory advice
                instead of the raw message.
        """
        import html

        # Escape the message before embedding it in markup; otherwise text
        # such as "<module>" in a traceback is parsed as a tag and vanishes.
        safe_msg = html.escape(error_msg)
        self.output_text.setHtml(
            "<span style='color:red'>Error during generation:</span>"
            f"<br><pre>{safe_msg}</pre>"
        )
        self.progress_bar.setVisible(False)
        self.status_label.setText("Generation failed")
        self.set_controls_enabled(True)

        if "CUDA out of memory" in error_msg:
            QMessageBox.critical(
                self,
                "Memory Error",
                "CUDA ran out of memory. Please try:\n\n"
                "1. Reducing generation length\n"
                "2. Reducing sampling steps\n"
                "3. Using 8-bit or 4-bit quantization\n"
                "4. Switching to CPU mode"
            )
        else:
            QMessageBox.critical(self, "Generation Error", f"An error occurred during generation:\n\n{error_msg}")
    
    def set_controls_enabled(self, enabled):
        """Enable or disable the input controls around a generation run.

        Args:
            enabled: True to unlock the controls (idle state), False to lock
                them while a generation is running.
        """
        # Controls that simply follow the requested state
        for control in (
            self.generate_btn,
            self.input_text,
            self.clear_btn,
            self.gen_length_spin,
            self.steps_spin,
            self.block_length_spin,
            self.temperature_spin,
            self.cfg_scale_spin,
            self.remasking_combo,
            self.cpu_radio,
            self.use_normal,
        ):
            control.setEnabled(enabled)

        # GPU can only be offered when CUDA is available, and the
        # quantization modes only when the GPU is the selected device.
        self.gpu_radio.setEnabled(enabled and torch.cuda.is_available())
        gpu_selected = self.gpu_radio.isChecked()
        self.use_8bit.setEnabled(enabled and gpu_selected)
        self.use_4bit.setEnabled(enabled and gpu_selected)

        # Safely check whether a worker is running. Only RuntimeError is
        # caught: that is what PyQt raises when the underlying C++ object has
        # already been deleted. A bare except would also swallow
        # KeyboardInterrupt and SystemExit.
        is_worker_running = False
        if self.worker is not None:
            try:
                is_worker_running = self.worker.isRunning()
            except RuntimeError:
                is_worker_running = False

        # Only enable stop button when worker is running
        self.stop_btn.setEnabled(not enabled and is_worker_running)

        # Update button text
        self.generate_btn.setText("Generate" if enabled else "Generating...")
    
    def clear_output(self):
        """Remove all text from the output pane."""
        viewer = self.output_text
        viewer.clear()


def main():
    """Create the Qt application, show the main window and run the event loop."""
    qt_app = QApplication(sys.argv)

    main_window = LLaDAGUI()
    main_window.show()

    # Hand the event loop's return code back to the shell
    exit_code = qt_app.exec()
    sys.exit(exit_code)


# Launch the GUI only when this file is run as a script, not when it is imported.
if __name__ == "__main__":
    main()
